{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a652d759",
   "metadata": {},
   "source": [
    "# {class}`~tvm.relax.training.SetupTrainer`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "171a84fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "import tvm.testing\n",
    "\n",
    "from tvm import relax, TVMError\n",
    "from tvm.ir.base import assert_structural_equal\n",
    "from tvm.relax.training import SetupTrainer\n",
    "from tvm.relax.training.optimizer import SGD, MomentumSGD\n",
    "from tvm.relax.training.loss import MSELoss\n",
    "from tvm.script import ir as I, relax as R"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2509f131",
   "metadata": {},
   "source": [
    "## 测试简单模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c531b102",
   "metadata": {},
   "outputs": [],
   "source": [
    "@I.ir_module\n",
    "class Backbone:\n",
    "    I.module_attrs({\"param_num\": 1, \"state_num\": 0})\n",
    "    @R.function\n",
    "    def backbone(x: R.Tensor((2, 2), \"float64\"), y: R.Tensor((2, 2), \"float64\")):\n",
    "        with R.dataflow():\n",
    "            x1 = x + y\n",
    "            R.output(x1)\n",
    "        return x1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bae8c8db",
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    I<span style=\"color: #AA22FF; font-weight: bold\">.</span>module_attrs({<span style=\"color: #BA2121\">&quot;input_num&quot;</span>: <span style=\"color: #008000\">1</span>, <span style=\"color: #BA2121\">&quot;optim_state&quot;</span>: [metadata[<span style=\"color: #BA2121\">&quot;runtime.NDArray&quot;</span>][<span style=\"color: #008000\">0</span>]], <span style=\"color: #BA2121\">&quot;param_num&quot;</span>: <span style=\"color: #008000\">1</span>, <span style=\"color: #BA2121\">&quot;state_num&quot;</span>: <span style=\"color: #008000\">0</span>})\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">backbone</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            x1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(x, y)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(x1)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> x1\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">backbone_loss</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), targets: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            x1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(x, y)\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>subtract(x1, targets)\n",
       "            lv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>multiply(lv, lv)\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>sum(lv1, axis<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>, keepdims<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">False</span>)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">backbone_loss_adjoint</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), targets: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>))):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            x1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(x, y)\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>subtract(x1, targets)\n",
       "            lv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>multiply(lv, lv)\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>sum(lv1, axis<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>, keepdims<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">False</span>)\n",
       "            gv_adjoint: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>ones(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>shape([]), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)\n",
       "            lv1_adjoint: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>broadcast_to(gv_adjoint, R<span style=\"color: #AA22FF; font-weight: bold\">.</span>shape([<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>]))\n",
       "            lv_adjoint: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>multiply(lv1_adjoint, lv)\n",
       "            lv_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>multiply(lv1_adjoint, lv)\n",
       "            lv_adjoint1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(lv_adjoint, lv_1)\n",
       "            x1_adjoint: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> lv_adjoint1\n",
       "            y_adjoint: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> x1_adjoint\n",
       "            y_adjoint_out: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> y_adjoint\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv, y_adjoint_out)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> (gv, (y_adjoint_out,))\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">optimizer</span>(params: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)), gradients: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)), optim_states: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>))) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>))):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            num_steps: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> optim_states[<span style=\"color: #008000\">0</span>]\n",
       "            num_steps_new: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(num_steps, R<span style=\"color: #AA22FF; font-weight: bold\">.</span>const(<span style=\"color: #008000\">1</span>, <span style=\"color: #BA2121\">&quot;int64&quot;</span>))\n",
       "            y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> params[<span style=\"color: #008000\">0</span>]\n",
       "            y_grad: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> gradients[<span style=\"color: #008000\">0</span>]\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>multiply(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>const(<span style=\"color: #008000\">0.10000000000000001</span>, <span style=\"color: #BA2121\">&quot;float64&quot;</span>), y_grad)\n",
       "            y_new: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>subtract(y, lv)\n",
       "            params_new: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">=</span> (y_new,)\n",
       "            optim_states_new: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">=</span> (num_steps_new,)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(params_new, optim_states_new)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> (params_new, optim_states_new)\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "sinfo = relax.TensorStructInfo((2, 2), \"float64\")\n",
    "setup_trainer = SetupTrainer(MSELoss(reduction=\"sum\"), SGD(0.1), [sinfo, sinfo], legalize=False)\n",
    "train_mod = setup_trainer(Backbone)\n",
    "train_mod.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f943883",
   "metadata": {},
   "source": [
    "## 测试状态"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ac90d4db",
   "metadata": {},
   "outputs": [],
   "source": [
    "@I.ir_module\n",
    "class Backbone:\n",
    "    I.module_attrs({\"param_num\": 1, \"state_num\": 1})\n",
    "    @R.function\n",
    "    def backbone(x: R.Tensor((2, 2), \"float64\"), y: R.Tensor((2, 2), \"float64\"), z: R.Tensor((2, 2), \"float64\")):\n",
    "        with R.dataflow():\n",
    "            x1 = x + y\n",
    "            z1 = z + R.const(1, \"float64\")\n",
    "            R.output(x1, z1)\n",
    "        return x1, z1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b7337343",
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    I<span style=\"color: #AA22FF; font-weight: bold\">.</span>module_attrs({<span style=\"color: #BA2121\">&quot;input_num&quot;</span>: <span style=\"color: #008000\">1</span>, <span style=\"color: #BA2121\">&quot;optim_state&quot;</span>: [metadata[<span style=\"color: #BA2121\">&quot;runtime.NDArray&quot;</span>][<span style=\"color: #008000\">0</span>], metadata[<span style=\"color: #BA2121\">&quot;runtime.NDArray&quot;</span>][<span style=\"color: #008000\">1</span>]], <span style=\"color: #BA2121\">&quot;param_num&quot;</span>: <span style=\"color: #008000\">1</span>, <span style=\"color: #BA2121\">&quot;state_num&quot;</span>: <span style=\"color: #008000\">1</span>})\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">backbone</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), z: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            x1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(x, y)\n",
       "            z1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(z, R<span style=\"color: #AA22FF; font-weight: bold\">.</span>const(<span style=\"color: #008000\">1.0</span>, <span style=\"color: #BA2121\">&quot;float64&quot;</span>))\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(x1, z1)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> (x1, z1)\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">backbone_loss</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), z: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), targets: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            x1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(x, y)\n",
       "            z1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(z, R<span style=\"color: #AA22FF; font-weight: bold\">.</span>const(<span style=\"color: #008000\">1.0</span>, <span style=\"color: #BA2121\">&quot;float64&quot;</span>))\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>subtract(x1, targets)\n",
       "            lv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>multiply(lv, lv)\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>sum(lv1, axis<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>, keepdims<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">False</span>)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(z1, gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> (gv, z1)\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">backbone_loss_adjoint</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), z: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), targets: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>))):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            x1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(x, y)\n",
       "            z1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(z, R<span style=\"color: #AA22FF; font-weight: bold\">.</span>const(<span style=\"color: #008000\">1.0</span>, <span style=\"color: #BA2121\">&quot;float64&quot;</span>))\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>subtract(x1, targets)\n",
       "            lv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>multiply(lv, lv)\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>sum(lv1, axis<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>, keepdims<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">False</span>)\n",
       "            gv_adjoint: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>ones(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>shape([]), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)\n",
       "            lv1_adjoint: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>broadcast_to(gv_adjoint, R<span style=\"color: #AA22FF; font-weight: bold\">.</span>shape([<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>]))\n",
       "            lv_adjoint: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>multiply(lv1_adjoint, lv)\n",
       "            lv_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>multiply(lv1_adjoint, lv)\n",
       "            lv_adjoint1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(lv_adjoint, lv_1)\n",
       "            x1_adjoint: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> lv_adjoint1\n",
       "            y_adjoint: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> x1_adjoint\n",
       "            y_adjoint_out: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> y_adjoint\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(z1, gv, y_adjoint_out)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> ((gv, z1), (y_adjoint_out,))\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">optimizer</span>(params: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)), gradients: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)), optim_states: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>))) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>))):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            num_steps: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> optim_states[<span style=\"color: #008000\">0</span>]\n",
       "            num_steps_new: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(num_steps, R<span style=\"color: #AA22FF; font-weight: bold\">.</span>const(<span style=\"color: #008000\">1</span>, <span style=\"color: #BA2121\">&quot;int64&quot;</span>))\n",
       "            y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> params[<span style=\"color: #008000\">0</span>]\n",
       "            y_grad: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> gradients[<span style=\"color: #008000\">0</span>]\n",
       "            y_v: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> optim_states[<span style=\"color: #008000\">1</span>]\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>multiply(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>const(<span style=\"color: #008000\">0.10000000000000001</span>, <span style=\"color: #BA2121\">&quot;float64&quot;</span>), y_v)\n",
       "            y_v_new: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(lv, y_grad)\n",
       "            lv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>multiply(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>const(<span style=\"color: #008000\">0.10000000000000001</span>, <span style=\"color: #BA2121\">&quot;float64&quot;</span>), y_v_new)\n",
       "            y_new: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>subtract(y, lv1)\n",
       "            params_new: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">=</span> (y_new,)\n",
       "            optim_states_new: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;int64&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float64&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">=</span> num_steps_new, y_v_new\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(params_new, optim_states_new)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> (params_new, optim_states_new)\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "sinfo = relax.TensorStructInfo((2, 2), \"float64\")\n",
    "setup_trainer = SetupTrainer(\n",
    "    MSELoss(reduction=\"sum\"), MomentumSGD(0.1, 0.1), [sinfo, sinfo], legalize=False\n",
    ")\n",
    "train_mod = setup_trainer(Backbone)\n",
    "train_mod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aab9cef",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
